notebooks/Unit 9 - Model Optimization/tmo.ipynb

{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# TensorFlow Model Optimization Toolkit (TMO)\n", "\n", "In this notebook, we demonstrate how to use TMO to optimize a model for deployment. We train a small CNN on the MNIST dataset, optimize it with post-training quantization and with pruning, and then compare the size and accuracy of each optimized model with the original model." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Set Up TMO\n", "\n", "First, we install TMO and import the required packages. TMO relies on Keras 2, while TensorFlow 2.16 and later default to Keras 3, so we also install the `tf-keras` package and set `TF_USE_LEGACY_KERAS=1` before importing TensorFlow." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%pip install -q tensorflow\n", "%pip install -q tensorflow-model-optimization\n", "# TMO needs Keras 2; on TensorFlow 2.16+ it is provided by the tf-keras package.\n", "%pip install -q tf-keras" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import os\n", "\n", "# Make tf.keras resolve to Keras 2 (tf-keras); this must be set before TensorFlow is imported.\n", "os.environ[\"TF_USE_LEGACY_KERAS\"] = \"1\"\n", "\n", "import tensorflow as tf\n", "import tensorflow_model_optimization as tfmot\n", "from tensorflow import keras\n", "import pathlib\n", "import numpy as np\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Post-Training Quantization\n", "\n", "Post-training quantization reduces the precision of an already-trained float model's weights from 32-bit floating point to 8-bit integers. It is applied when the model is converted to TensorFlow Lite format with the [TensorFlow Lite Converter](https://www.tensorflow.org/lite/models/convert/). With only `tf.lite.Optimize.DEFAULT` enabled and no representative dataset, as in this notebook, the converter performs *dynamic range quantization*: weights are stored as 8-bit integers, while activations stay in floating point." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Load MNIST dataset\n", "\n", "We load the MNIST dataset from Keras and prepare it for training." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Load MNIST dataset\n", "mnist = keras.datasets.mnist\n", "(train_images, train_labels), (test_images, test_labels) = mnist.load_data()\n", "\n", "# Normalize the input images so that each pixel value is between 0 and 1.\n", "train_images = train_images / 255.0\n", "test_images = test_images / 255.0" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Train the Model\n", "\n", "Next, we define a small CNN model and train it on the MNIST dataset." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Define the model architecture\n", "model = keras.Sequential([\n", "    keras.layers.InputLayer(input_shape=(28, 28)),\n", "    keras.layers.Reshape(target_shape=(28, 28, 1)),\n", "    keras.layers.Conv2D(filters=12, kernel_size=(3, 3), activation=tf.nn.relu),\n", "    keras.layers.MaxPooling2D(pool_size=(2, 2)),\n", "    keras.layers.Flatten(),\n", "    keras.layers.Dense(10)\n", "])\n", "\n", "# Train the digit classification model\n", "model.compile(optimizer='adam',\n", "              loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n", "              metrics=['accuracy'])\n", "model.fit(\n", "    train_images,\n", "    train_labels,\n", "    epochs=1,\n", "    validation_data=(test_images, test_labels)\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Convert Model to TFLite\n", "\n", "After training, we convert the model to [TFLite](https://www.tensorflow.org/lite/guide) format twice: once without any optimization, to serve as the baseline, and once with quantization applied during the conversion." 
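] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To make the 32-bit to 8-bit mapping concrete, the next cell is a minimal NumPy sketch of *symmetric per-tensor* int8 quantization applied to a random weight array `w`, a stand-in for a real layer's kernel. It is an illustrative assumption, not the TensorFlow Lite converter's exact algorithm, which may differ in details such as using a separate scale per output channel." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Minimal sketch (illustrative assumption): symmetric per-tensor int8 quantization.\n", "import numpy as np\n", "\n", "rng = np.random.default_rng(0)\n", "w = rng.normal(0.0, 0.1, size=1000).astype(np.float32)  # stand-in for a layer's weights\n", "\n", "# A single float scale maps the largest |w| to 127.\n", "scale = np.abs(w).max() / 127.0\n", "w_int8 = np.clip(np.round(w / scale), -127, 127).astype(np.int8)\n", "\n", "# Dequantize to measure how much precision was lost.\n", "w_restored = w_int8.astype(np.float32) * scale\n", "\n", "print('float32 bytes:', w.nbytes)  # 4 bytes per weight\n", "print('int8 bytes:', w_int8.nbytes)  # 1 byte per weight\n", "print('max abs error:', np.abs(w - w_restored).max(), 'half a step:', scale / 2)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Each weight now takes 1 byte instead of 4, and rounding keeps every restored value within about half a quantization step (`scale / 2`) of the original. The next cell performs the actual conversion with the TFLite converter."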
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "tflite_models_dir = pathlib.Path(\"notebooks/Unit 9 - Model Optimization/models\")\n", "tflite_models_dir.mkdir(exist_ok=True, parents=True)\n", "converter = tf.lite.TFLiteConverter.from_keras_model(model)\n", "\n", "# Without quantization: the float32 baseline.\n", "tflite_model = converter.convert()\n", "tflite_model_file = tflite_models_dir/\"original_model.tflite\"\n", "tflite_model_file.write_bytes(tflite_model)\n", "\n", "# With quantization: Optimize.DEFAULT without a representative dataset stores weights as int8.\n", "converter.optimizations = [tf.lite.Optimize.DEFAULT]\n", "tflite_quant_model = converter.convert()\n", "tflite_model_quant_file = tflite_models_dir/\"quantized_model.tflite\"\n", "tflite_model_quant_file.write_bytes(tflite_quant_model)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Check Model Size\n", "\n", "The quantized model should be noticeably smaller than the original, because its weights are stored as 8-bit integers instead of 32-bit floats." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Quote the path: the directory name contains spaces.\n", "%ls -lh \"{tflite_models_dir}\"" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Check Model Accuracy\n", "\n", "Next, we evaluate the accuracy of the original and quantized TFLite models on the test dataset and compare them.\n", "The accuracy of the quantized model is expected to stay very close to that of the original model; check the printed values to confirm." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# A helper function to evaluate a TF Lite model on the test dataset.\n", "def evaluate_model(interpreter):\n", "    input_index = interpreter.get_input_details()[0][\"index\"]\n", "    output_index = interpreter.get_output_details()[0][\"index\"]\n", "\n", "    # Run predictions on every image in the test dataset.\n", "    prediction_digits = []\n", "    for test_image in test_images:\n", "        # Pre-processing: add a batch dimension and convert to float32 to match\n", "        # the model's input data format.\n", "        test_image = np.expand_dims(test_image, axis=0).astype(np.float32)\n", "        interpreter.set_tensor(input_index, test_image)\n", "\n", "        # Run inference.\n", "        interpreter.invoke()\n", "\n", "        # Post-processing: remove the batch dimension and pick the digit with the\n", "        # highest logit (the model outputs logits, so argmax needs no softmax).\n", "        output = interpreter.tensor(output_index)\n", "        digit = np.argmax(output()[0])\n", "        prediction_digits.append(digit)\n", "\n", "    # Compare prediction results with ground truth labels to calculate accuracy.\n", "    accurate_count = 0\n", "    for index in range(len(prediction_digits)):\n", "        if prediction_digits[index] == test_labels[index]:\n", "            accurate_count += 1\n", "    accuracy = accurate_count * 1.0 / len(prediction_digits)\n", "\n", "    return accuracy\n", "\n", "\n", "interpreter = tf.lite.Interpreter(model_path=str(tflite_model_file))\n", "interpreter.allocate_tensors()\n", "print(\"Original model accuracy = \", evaluate_model(interpreter))\n", "\n", "\n", "interpreter_quant = tf.lite.Interpreter(model_path=str(tflite_model_quant_file))\n", "interpreter_quant.allocate_tensors()\n", "print(\"Quantized model accuracy = \", evaluate_model(interpreter_quant))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Pruning\n", "\n", "Pruning reduces the size of a model by removing weights that contribute little to its predictions. TMO's magnitude-based pruning treats the weights with the smallest absolute values as the least important and progressively sets them to zero while the model is fine-tuned, so the remaining weights can compensate. Here, the fraction of zeroed weights (the sparsity) in each pruned layer rises from 50% to 80% over 2 epochs of fine-tuning, following a `PolynomialDecay` schedule." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude\n", "\n", "# Compute end step to finish pruning after 2 epochs.\n", "batch_size = 128\n", "epochs = 2\n", "validation_split = 0.1  # 10% of the training set is used for validation.\n", "\n", "num_images = train_images.shape[0] * (1 - validation_split)\n", "end_step = np.ceil(num_images / batch_size).astype(np.int32) * epochs\n", "\n", "# Define model for pruning.\n", "pruning_params = {\n", "    'pruning_schedule': tfmot.sparsity.keras.PolynomialDecay(initial_sparsity=0.50,\n", "                                                             final_sparsity=0.80,\n", "                                                             begin_step=0,\n", "                                                             end_step=end_step)\n", "}\n", "\n", "model_for_pruning = prune_low_magnitude(model, **pruning_params)\n", "\n", "# `prune_low_magnitude` requires a recompile.\n", "model_for_pruning.compile(optimizer='adam',\n", "                          loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n", "                          metrics=['accuracy'])\n", "\n", "model_for_pruning.summary()\n", "\n", "callbacks = [\n", "    tfmot.sparsity.keras.UpdatePruningStep(),\n", "]\n", "\n", "model_for_pruning.fit(train_images, train_labels,\n", "                      batch_size=batch_size, epochs=epochs, validation_split=validation_split,\n", "                      callbacks=callbacks)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Compare Accuracy\n", "\n", "The pruned model's test accuracy is expected to stay close to the baseline's, even though about 80% of the weights in its pruned layers are now zero." 
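] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Before comparing accuracy, it helps to see what the pruning schedule does to the weights. The next cell is a simplified NumPy sketch of magnitude pruning driven by a polynomial sparsity schedule. The `polynomial_sparsity` and `magnitude_prune` helpers are hypothetical illustrations, not TMO functions: the schedule formula assumes `PolynomialDecay`'s default `power=3`, and the sketch ignores that the real pruning wrapper only updates its masks every `frequency` steps. The value `end=844` is what the pruning cell computes for MNIST (54,000 training images / batch size 128, rounded up to 422 steps per epoch, times 2 epochs)." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Simplified sketch (hypothetical helpers, not TMO APIs) of magnitude pruning\n", "# driven by a polynomial sparsity schedule.\n", "import numpy as np\n", "\n", "def polynomial_sparsity(step, initial=0.50, final=0.80, begin=0, end=844, power=3):\n", "    # Fraction of the schedule completed, clipped to [0, 1].\n", "    progress = min(1.0, max(0.0, (step - begin) / (end - begin)))\n", "    # Sparsity starts at `initial` and reaches `final` when progress is 1.\n", "    return final + (initial - final) * (1.0 - progress) ** power\n", "\n", "def magnitude_prune(weights, sparsity):\n", "    # Zero out the `sparsity` fraction of entries with the smallest |weight|.\n", "    k = int(round(sparsity * weights.size))\n", "    if k == 0:\n", "        return weights.copy()\n", "    threshold = np.sort(np.abs(weights), axis=None)[k - 1]\n", "    return np.where(np.abs(weights) > threshold, weights, 0.0)\n", "\n", "rng = np.random.default_rng(0)\n", "# Same shape as the model's Dense kernel: 13 * 13 * 12 = 2028 inputs, 10 outputs.\n", "kernel = rng.normal(size=(2028, 10))\n", "\n", "for step in [0, 422, 844]:\n", "    s = polynomial_sparsity(step)\n", "    pruned = magnitude_prune(kernel, s)\n", "    print(f'step {step}: target sparsity {s:.4f}, zeroed fraction {np.mean(pruned == 0):.4f}')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "With these settings, the target sparsity starts at 50% and reaches 80% at the end step, and each pruning pass zeroes the smallest-magnitude weights until that target is met. The next cell compares the test accuracy of the baseline and pruned models."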
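] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# `prune_low_magnitude` wraps the original model's layers, so fine-tuning may also\n", "# have changed `model`. Use the float TFLite model exported before pruning as the baseline.\n", "baseline_model_accuracy = evaluate_model(interpreter)\n", "_, model_for_pruning_accuracy = model_for_pruning.evaluate(\n", "    test_images, test_labels, verbose=0)\n", "\n", "print('Baseline test accuracy:', baseline_model_accuracy)\n", "print('Pruned test accuracy:', model_for_pruning_accuracy)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Compare Model Size\n", "\n", "Lastly, we compare the size of the pruned model with the original model. Pruning sets weights to zero, but by default the TFLite file still stores every weight, so the uncompressed files are about the same size. The savings appear once the files are compressed, because the many zeros compress well; the export cell below therefore also reports gzip-compressed sizes." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import gzip\n", "\n", "# Remove the pruning wrappers so the exported model contains only the original layers.\n", "model_for_export = tfmot.sparsity.keras.strip_pruning(model_for_pruning)\n", "\n", "pruning_converter = tf.lite.TFLiteConverter.from_keras_model(model_for_export)\n", "pruned_tflite_model = pruning_converter.convert()\n", "pruned_model_file = tflite_models_dir/\"pruned_model.tflite\"\n", "pruned_model_file.write_bytes(pruned_tflite_model)\n", "\n", "# Zeroed weights only pay off after compression, so compare gzip-compressed sizes too.\n", "for name, blob in [('original', tflite_model), ('pruned', pruned_tflite_model)]:\n", "    print(f'{name}: {len(blob)} bytes raw, {len(gzip.compress(blob))} bytes gzipped')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Quote the path: the directory name contains spaces.\n", "%ls -lh \"{tflite_models_dir}\"" ] } ], "metadata": { "kernelspec": { "display_name": "model_optimization", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.13" } }, "nbformat": 4, "nbformat_minor": 2 }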